#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
import multiprocessing
import os
import signal
import subprocess
import time
from contextlib import suppress
from subprocess import CalledProcessError
from time import sleep
from unittest import mock

import psutil
import pytest

from airflow.exceptions import AirflowException
from airflow.utils import process_utils
from airflow.utils.process_utils import (
    check_if_pidfile_process_is_running,
    execute_in_subprocess,
    execute_in_subprocess_with_kwargs,
    set_new_process_group,
)


class TestReapProcessGroup:
    """Tests for ``process_utils.reap_process_group`` against processes that ignore SIGTERM."""

    @staticmethod
    def _ignores_sigterm(child_pid, child_setup_done):
        """
        Child-process target: ignore SIGTERM, publish this PID, signal readiness, idle forever.

        :param child_pid: shared ``multiprocessing.Value`` that receives this process's PID.
        :param child_setup_done: semaphore released once the handler is installed and PID published.
        """

        def signal_handler(unused_signum, unused_frame):
            pass

        signal.signal(signal.SIGTERM, signal_handler)
        child_pid.value = os.getpid()
        child_setup_done.release()
        while True:
            time.sleep(1)

    @staticmethod
    def _parent_of_ignores_sigterm(parent_pid, child_pid, setup_done):
        """
        Parent-process target: start a new session, ignore SIGTERM, spawn a SIGTERM-ignoring child.

        After the child reports readiness (or the 5 s wait times out), publish this PID and
        release ``setup_done``, then idle forever.

        :param parent_pid: shared ``multiprocessing.Value`` that receives this process's PID.
        :param child_pid: shared ``multiprocessing.Value`` forwarded to the child.
        :param setup_done: semaphore released once this process's PID is published.
        """

        def signal_handler(unused_signum, unused_frame):
            pass

        # New session, so this process's PID is also its process-group ID.
        os.setsid()
        signal.signal(signal.SIGTERM, signal_handler)
        child_setup_done = multiprocessing.Semaphore(0)
        child = multiprocessing.Process(
            target=TestReapProcessGroup._ignores_sigterm, args=[child_pid, child_setup_done]
        )
        child.start()
        # The result is ignored here; on timeout child_pid stays 0 and the test asserts on that.
        child_setup_done.acquire(timeout=5.0)
        parent_pid.value = os.getpid()
        setup_done.release()
        while True:
            time.sleep(1)

    def test_reap_process_group(self):
        """
        Spin up a process that can't be killed by SIGTERM and make sure
        it gets killed anyway.
        """
        parent_setup_done = multiprocessing.Semaphore(0)
        parent_pid = multiprocessing.Value("i", 0)
        child_pid = multiprocessing.Value("i", 0)
        args = [parent_pid, child_pid, parent_setup_done]
        parent = multiprocessing.Process(target=TestReapProcessGroup._parent_of_ignores_sigterm, args=args)
        try:
            parent.start()
            assert parent_setup_done.acquire(timeout=5.0)
            # psutil.pid_exists(0) is True on POSIX, so an unpublished (0) PID must be
            # rejected explicitly or the existence checks below pass vacuously.
            assert parent_pid.value > 0
            assert child_pid.value > 0
            assert psutil.pid_exists(parent_pid.value)
            assert psutil.pid_exists(child_pid.value)

            process_utils.reap_process_group(parent_pid.value, logging.getLogger(), timeout=1)

            assert not psutil.pid_exists(parent_pid.value)
            assert not psutil.pid_exists(child_pid.value)
        finally:
            # Best-effort cleanup. Kill each PID independently so an already-dead parent
            # (the normal success case) does not prevent the child from being killed.
            # Skip unpublished PIDs: os.kill(0, ...) would signal the test runner's own group.
            for pid in (parent_pid.value, child_pid.value):
                if pid <= 0:
                    continue
                with suppress(OSError):
                    os.kill(pid, signal.SIGKILL)  # terminate doesn't work here
            if parent.pid is not None:
                # Reap the parent if it is still a zombie of this process; returns at once otherwise.
                parent.join(timeout=5)


@pytest.mark.db_test
class TestExecuteInSubProcess:
    """Tests for ``execute_in_subprocess`` and ``execute_in_subprocess_with_kwargs`` logging/errors."""

    def test_should_print_all_messages1(self, caplog):
        execute_in_subprocess(["bash", "-c", "echo CAT; echo KITTY;"])
        expected_messages = [
            "Executing cmd: bash -c 'echo CAT; echo KITTY;'",
            "Output:",
            "CAT",
            "KITTY",
        ]
        assert caplog.messages == expected_messages

    def test_should_print_all_messages_from_cwd(self, caplog, tmp_path):
        work_dir = str(tmp_path)
        execute_in_subprocess(["bash", "-c", "echo CAT; pwd; echo KITTY;"], cwd=work_dir)
        expected_messages = [
            "Executing cmd: bash -c 'echo CAT; pwd; echo KITTY;'",
            "Output:",
            "CAT",
            work_dir,  # `pwd` prints the directory passed as cwd
            "KITTY",
        ]
        assert caplog.messages == expected_messages

    def test_using_env_works(self, caplog):
        command = ["bash", "-c", 'echo "My value is ${VALUE}"']
        execute_in_subprocess(command, env={"VALUE": "1"})
        assert "My value is 1" in caplog.text

    def test_should_raise_exception(self):
        # A non-zero exit status surfaces as CalledProcessError.
        with pytest.raises(CalledProcessError):
            execute_in_subprocess(["bash", "-c", "exit 1"])

    def test_using_env_as_kwarg_works(self, caplog):
        command = ["bash", "-c", 'echo "My value is ${VALUE}"']
        execute_in_subprocess_with_kwargs(command, env={"VALUE": "1"})
        assert "My value is 1" in caplog.text


def my_sleep_subprocess():
    """Child-process target that blocks for 100 seconds so a test can kill it meanwhile."""
    duration_seconds = 100
    sleep(duration_seconds)


def my_sleep_subprocess_with_signals():
    """
    Child-process target that ignores SIGINT and SIGTERM, then blocks for 100 seconds.

    Stopping it early therefore needs a signal that cannot be handled, such as SIGKILL.
    """

    def _do_nothing(signum, frame):
        return None

    for ignored_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(ignored_signal, _do_nothing)
    sleep(100)


@pytest.mark.db_test
class TestKillChildProcessesByPids:
    """Tests for ``process_utils.kill_child_processes_by_pids``."""

    @staticmethod
    def _running_pids() -> set[str]:
        """Return every PID listed by ``ps -ax`` as a whitespace-stripped string."""
        output = subprocess.check_output(["ps", "-ax", "-o", "pid="]).decode()
        return {line.strip() for line in output.splitlines()}

    def _cleanup(self, process: multiprocessing.Process) -> None:
        """Force-kill ``process`` if a failed assertion left it running, then reap it."""
        if str(process.pid) in self._running_pids():
            process.kill()
        # Returns immediately if the process was already reaped elsewhere.
        process.join(timeout=5)

    def test_should_kill_process(self):
        process = multiprocessing.Process(target=my_sleep_subprocess, args=())
        process.start()
        try:
            sleep(0)
            # Check this specific PID rather than comparing system-wide process counts,
            # which change whenever any unrelated process on the host starts or exits.
            assert str(process.pid) in self._running_pids()

            process_utils.kill_child_processes_by_pids([process.pid])

            assert str(process.pid) not in self._running_pids()
        finally:
            self._cleanup(process)

    def test_should_force_kill_process(self, caplog):
        process = multiprocessing.Process(target=my_sleep_subprocess_with_signals, args=())
        process.start()
        try:
            sleep(0)
            assert str(process.pid) in self._running_pids()

            with caplog.at_level(logging.INFO, logger=process_utils.log.name):
                caplog.clear()
                # timeout=0: the child ignores SIGTERM, so it must be force-killed.
                process_utils.kill_child_processes_by_pids([process.pid], timeout=0)
                assert f"Killing child PID: {process.pid}" in caplog.messages
            sleep(0)
            assert str(process.pid) not in self._running_pids()
        finally:
            self._cleanup(process)


class TestPatchEnviron:
    """``patch_environ`` must override variables temporarily and restore the prior environment."""

    @staticmethod
    def _assert_baseline():
        """Baseline: TEST_EXISTS holds "BEFORE" and TEST_NOT_EXISTS is unset."""
        assert os.environ["TEST_EXISTS"] == "BEFORE"
        assert "TEST_NOT_EXISTS" not in os.environ

    @staticmethod
    def _assert_patched():
        """Inside ``patch_environ``: both variables hold "AFTER"."""
        assert os.environ["TEST_NOT_EXISTS"] == "AFTER"
        assert os.environ["TEST_EXISTS"] == "AFTER"

    def test_should_update_variable_and_restore_state_when_exit(self):
        with mock.patch.dict("os.environ", {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}):
            # Seed both keys via patch.dict, then drop one so it starts out unset.
            del os.environ["TEST_NOT_EXISTS"]
            self._assert_baseline()

            with process_utils.patch_environ({"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}):
                self._assert_patched()

            self._assert_baseline()

    def test_should_restore_state_when_exception(self):
        with mock.patch.dict("os.environ", {"TEST_NOT_EXISTS": "BEFORE", "TEST_EXISTS": "BEFORE"}):
            del os.environ["TEST_NOT_EXISTS"]
            self._assert_baseline()

            # patch_environ exits (and restores) before suppress swallows the exception.
            with suppress(AirflowException), process_utils.patch_environ(
                {"TEST_NOT_EXISTS": "AFTER", "TEST_EXISTS": "AFTER"}
            ):
                self._assert_patched()
                raise AirflowException("Unknown exception")

            self._assert_baseline()


class TestCheckIfPidfileProcessIsRunning:
    """Tests for ``check_if_pidfile_process_is_running``."""

    def test_ok_if_no_file(self, tmp_path):
        # Build the missing path under tmp_path so the outcome does not depend on the
        # current working directory (a cwd-relative path could accidentally exist).
        path = tmp_path / "some" / "pid" / "file"
        check_if_pidfile_process_is_running(os.fspath(path), process_name="test")

    def test_remove_if_no_process(self, tmp_path):
        path = tmp_path / "testfile"
        # limit pid as max of int32, otherwise this test could fail on some platform
        path.write_text(f"{2**31 - 1}")
        check_if_pidfile_process_is_running(os.fspath(path), process_name="test")
        # A pidfile pointing at a non-existent process is removed.
        assert not path.exists()

    def test_raise_error_if_process_is_running(self, tmp_path):
        path = tmp_path / "testfile"
        # The current test process is guaranteed to be alive.
        pid = os.getpid()
        path.write_text(f"{pid}")
        with pytest.raises(AirflowException, match="is already running under PID"):
            check_if_pidfile_process_is_running(os.fspath(path), process_name="test")


class TestSetNewProcessGroup:
    """``set_new_process_group`` calls ``os.setpgid`` only when not already a session leader."""

    def test_not_session_leader(self):
        own_pid = os.getpid()
        with mock.patch("os.setpgid") as setpgid_mock, mock.patch("os.getsid", autospec=True) as getsid_mock:
            # Session ID differs from our PID -> not the session leader.
            getsid_mock.return_value = own_pid + 1
            set_new_process_group()
            setpgid_mock.assert_called_once()

    def test_session_leader(self):
        own_pid = os.getpid()
        with mock.patch("os.setpgid") as setpgid_mock, mock.patch("os.getsid", autospec=True) as getsid_mock:
            # Session ID equals our PID -> already the session leader.
            getsid_mock.return_value = own_pid
            set_new_process_group()
            setpgid_mock.assert_not_called()
